Machine learning to segment neutron images

Anders Kaestner, Beamline scientist - Neutron Imaging

Laboratory for Neutron Scattering and Imaging
Paul Scherrer Institut

Lecture outline

  1. Introduction
  2. Limited data problem
  3. Unsupervised segmentation
  4. Supervised segmentation
  5. Final problem: Segmenting root networks using convolutional NNs
  6. Future Machine learning challenges in NI

Getting started

If you want to run the notebook on your own computer, you'll need to perform the following steps:

  • You will need to install Anaconda
  • Clone the lecture repository (in the location you'd like to have it)
    git clone https://github.com/ImagingLectures/MLSegmentation4NI.git
    
  • Enter the folder 'MLSegmentation4NI'
  • Create an environment for the notebook
conda env create -f environment.yml -n MLSeg4NI
    
  • Enter the environment
conda activate MLSeg4NI
    

Importing needed modules

This lecture needs some modules to run. We import all of them here.

In [1]:
import matplotlib.pyplot as plt
import seaborn           as sn
import numpy             as np
import pandas            as pd
import skimage.filters   as flt
import skimage.io        as io
import matplotlib        as mpl

from sklearn.cluster     import KMeans
from sklearn.neighbors   import KNeighborsClassifier
from sklearn.metrics     import confusion_matrix
from sklearn.datasets    import make_blobs

from matplotlib.colors   import ListedColormap
from lecturesupport      import plotsupport as ps

import scipy.stats       as stats
import astropy.io.fits   as fits

from keras.models        import Model
from keras.layers        import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate

%matplotlib inline


from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'png')
#plt.style.use('seaborn')
mpl.rcParams['figure.dpi'] = 300
Using TensorFlow backend.
In [2]:
import importlib
importlib.reload(ps);

Introduction

  • Introduction to neutron imaging

    • Some words about the method
    • Contrasts
  • Introduction to segmentation

    • What is segmentation
    • Noise and SNR
  • Problematic segmentation tasks

    • Intro
    • Segmentation problems in neutron imaging

What is an image?

A very abstract definition:

  • A pairing between spatial information (position)
  • and some other kind of information (value).

In most cases this is a two- or three-dimensional position (x,y,z coordinates) and a numeric value (intensity)

Science and Imaging

Images are great for qualitative analyses since our brains can quickly interpret them without large programming investments.

Proper processing and quantitative analysis is however much more difficult with images.

  • If you measure a temperature, quantitative analysis is easy, $T=50K$.
  • If you measure an image, the analysis is much more difficult and much more prone to mistakes:
    • subtle setup variations may break your analysis process,
    • and unclear problem definitions lead to confusing analyses.

Furthermore, in image processing a plethora of tools is available:

  • Thousands of algorithms available
  • Thousands of tools
  • Many images require multi-step processing
  • Experimenting is time-consuming

Some words about neutron imaging

$$I=I_0\cdot{}e^{-\int_L \mu{}(x) dx}$$
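The line integral in the exponent can be evaluated numerically; the attenuation profile below is a made-up example to illustrate the law, not measured data:

```python
import numpy as np

# Discretized Beer-Lambert law: I = I0 * exp(-sum(mu_i * dx))
dx = 0.01              # step length along the ray [cm] (assumed)
mu = np.zeros(200)     # attenuation coefficient along the ray [1/cm]
mu[50:150] = 1.2       # a 1 cm thick homogeneous sample (assumed value)
I0 = 1.0               # open-beam intensity
I = I0 * np.exp(-np.sum(mu * dx))   # transmitted intensity, exp(-1.2)
```

The sum approximates the integral over the path L; a finer `dx` with a correspondingly longer `mu` array improves the approximation.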

Neutron imaging contrast


(Figures: transmission through a sample; X-ray attenuation; neutron attenuation.)

Measurements are rarely perfect

Factors affecting the image quality

  • Resolution (Imaging system transfer functions)
  • Noise
  • Contrast
  • Inhomogeneous contrast
  • Artifacts

Introduction to segmentation

Different types of segmentation

Basic segmentation: Applying a threshold to an image

Start out with a simple image of a cross with added noise

$$ I(x,y) = f(x,y) $$
In [3]:
fig,ax = plt.subplots(1,2,figsize=(12,6))
nx = 5; ny = 5;
xx, yy   = np.meshgrid(np.arange(-nx, nx+1)/nx*2*np.pi, np.arange(-ny, ny+1)/ny*2*np.pi)
cross_im = 1.5*np.abs(np.cos(xx*yy))/(np.abs(xx*yy)+(3*np.pi/nx)) + np.random.uniform(-0.25, 0.25, size = xx.shape)       

im=ax[0].imshow(cross_im, cmap = 'hot'); ax[0].set_title("Image")
ax[1].hist(cross_im.ravel(),bins=10); ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");

Applying a threshold to an image

Applying the threshold is a deceptively simple operation

$$ I(x,y) = \begin{cases} 1, & f(x,y)\geq0.40 \\ 0, & f(x,y)<0.40 \end{cases}$$
In [4]:
threshold = 0.4; thresh_img = cross_im > threshold
fig,ax = plt.subplots(1,2,figsize=(12,6))
ax[0].imshow(cross_im, cmap = 'hot', extent = [xx.min(), xx.max(), yy.min(), yy.max()]); ax[0].set_title("Image")
ax[0].plot(xx[np.where(thresh_img)]*0.9, yy[np.where(thresh_img)]*0.9,
           'ks', markerfacecolor = 'green', alpha = 0.5,label = 'Threshold', markersize = 22); ax[0].legend(fontsize=12);
ax[1].hist(cross_im.ravel(),bins=10); ax[1].axvline(x=threshold,color='r',label='Threshold'); ax[1].legend(fontsize=12); 
ax[1].set_xlabel('Gray value'); ax[1].set_ylabel('Counts'); ax[1].set_title("Histogram");

Noise and SNR

The noise in neutron imaging mainly originates from the limited number of captured neutrons.

This noise is Poisson distributed and the signal to noise ratio is

$$SNR=\frac{E[x]}{s[x]}\sim\frac{N}{\sqrt{N}}=\sqrt{N}$$
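This relation is easy to verify with simulated Poisson counts (the count level `N` below is an assumption):

```python
import numpy as np

# Empirical check of SNR = sqrt(N) for Poisson-distributed counts
rng = np.random.default_rng(2018)
N = 10000                              # expected neutron count per pixel (assumed)
counts = rng.poisson(N, size=100_000)  # simulated pixel counts
snr = counts.mean() / counts.std()     # should be close to sqrt(N) = 100
```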

Problematic segmentation tasks

(Figure: Woodland Encounter, Bev Doolittle.)

Typical image features that make life harder

Segmentation problems in neutron imaging

Limited data problem

Different types of limited data:

  • Few data points or limited amounts of images
  • Unbalanced data
  • Little or missing training data

Training data from NI is limited

  • Long experiment times
  • Few samples
  • Some recycling of data from previous experiments is possible.

Augmentation to increase training data

Data augmentation is a method to modify your existing data to obtain variations of it.

Retinal images from [DRIVE](https://drive.grand-challenge.org/DRIVE/) prepared by Gian Guido Parenza.

Augmentation will be used to increase the training data in the root segmentation example at the end of this lecture.
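A minimal augmentation sketch using numpy: flips and 90° rotations applied identically to an image and its mask (the function name is illustrative):

```python
import numpy as np

def augment(img, mask, rng):
    """Apply a random quarter-turn rotation and a random horizontal flip,
    identically to the image and its annotation mask."""
    k = int(rng.integers(0, 4))                    # number of quarter turns
    img, mask = np.rot90(img, k), np.rot90(mask, k)
    if rng.random() < 0.5:                         # flip with 50% probability
        img, mask = np.fliplr(img), np.fliplr(mask)
    return img, mask
```

Calling `augment` repeatedly on one annotated pair yields up to eight geometric variants; intensity jitter or elastic deformations could be added in the same spirit.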

Simulation to increase training data

  • Geometric models
  • Template models
  • Physical models

Both augmented and simulated data should be combined with real data.
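As a sketch of a geometric model, one can synthesize noisy disk images together with their ground-truth masks (all parameters below are illustrative, not from the lecture data):

```python
import numpy as np

def make_disk_sample(shape=(64, 64), rng=None):
    """Synthesize a noisy image of a random disk plus its ground-truth mask."""
    rng = rng if rng is not None else np.random.default_rng()
    yy, xx = np.mgrid[:shape[0], :shape[1]]
    cy = int(rng.integers(10, shape[0] - 10))   # keep the disk inside the image
    cx = int(rng.integers(10, shape[1] - 10))
    r = int(rng.integers(5, 10))                # disk radius in pixels
    mask = (yy - cy)**2 + (xx - cx)**2 <= r**2
    img = mask * 1.0 + rng.normal(0, 0.2, shape)   # 'dirty' image with known truth
    return img, mask
```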

Transfer learning

Transfer learning is a technique that uses a pre-trained network to

  • Speed up training on your current data
  • Support in cases of limited data
  • Improve network performance

Unsupervised segmentation

Introducing clustering

In [5]:
test_pts = pd.DataFrame(make_blobs(n_samples=200, random_state=2018)[
                        0], columns=['x', 'y'])
plt.plot(test_pts.x, test_pts.y, 'r.');

k-means

Basic clustering example

In [6]:
fig, ax = plt.subplots(1,3,figsize=(15,4.5))

for i in range(3) :
    km = KMeans(n_clusters=i+2, random_state=2018); n_grp = km.fit_predict(test_pts)
    ax[i].scatter(test_pts.x, test_pts.y, c=n_grp)
    ax[i].set_title('{0} groups'.format(i+2))

Add spatial information to k-means

When can clustering be used on images?

  • Single images
  • Bimodal data
  • Spectrum data
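One way to add spatial information to k-means is to append down-weighted pixel coordinates to the feature vector; the toy image and the weight `w` below are assumptions for illustration:

```python
import numpy as np
from sklearn.cluster import KMeans

# Toy image: a bright square on a dark background, with noise (made-up data).
rng = np.random.default_rng(2018)
img = np.zeros((40, 40))
img[10:30, 10:30] = 1.0
img += rng.normal(0, 0.1, img.shape)

# Append (down-weighted) pixel coordinates to the intensity feature.
yy, xx = np.meshgrid(np.arange(40), np.arange(40), indexing='ij')
w = 0.005   # spatial weight relative to intensity (an assumed tuning parameter)
features = np.stack([img.ravel(), w * xx.ravel(), w * yy.ravel()], axis=1)

seg = KMeans(n_clusters=2, n_init=10, random_state=2018).fit_predict(features)
seg = seg.reshape(img.shape)
```

Increasing `w` makes the clusters more spatially compact; with `w = 0` this reduces to plain intensity clustering.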

Clustering applied to wavelength resolved imaging

The imaging technique and its applications

The data

In [7]:
tof  = np.load('../data/tofdata.npy')
wtof = tof.mean(axis=2)
plt.imshow(wtof,cmap='gray'); 
plt.title('Average intensity all time bins');

Looking at the spectra

In [8]:
fig, ax= plt.subplots(1,2,figsize=(12,5))
ax[0].imshow(wtof,cmap='gray'); ax[0].set_title('Average intensity all time bins');
ax[0].plot(57,3,'ro'), ax[0].plot(15,30,'bo'), ax[0].plot(79,90,'go'); ax[0].plot(100,120,'co');
ax[1].plot(tof[30,15,:],'b', label='Sample'); ax[1].plot(tof[3,57,:],'r', label='Background'); ax[1].plot(tof[90,79,:],'g', label='Spacer'); ax[1].legend();ax[1].plot(tof[120,100,:],'c', label='Sample 2');

Reshaping

In [9]:
tofr=tof.reshape([tof.shape[0]*tof.shape[1],tof.shape[2]])
print("Input ToF dimensions",tof.shape)
print("Reshaped ToF data",tofr.shape)
Input ToF dimensions (128, 128, 661)
Reshaped ToF data (16384, 661)

Setting up and running k-means

  • We can clearly see that there is void on the sides of the specimens.
  • There is also a separating band between the specimens.
  • Finally we have to decide how many regions we want to find in the specimens. Let's start with two regions with different characteristics.
In [10]:
km = KMeans(n_clusters=4, random_state=2018)
c  = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose()             # cluster centroid spectra

Results from the first try

In [11]:
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='viridis'); axes[0].set_title('Average image')
p=axes[1].plot(kc);                  axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot

im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()

We need more clusters

  • The experiment data has variations in places where we didn't expect k-means to detect clusters.
  • We need to increase the number of clusters!
In [12]:
km = KMeans(n_clusters=10, random_state=2018)
c  = km.fit_predict(tofr).reshape(tof.shape[:2]) # Label image
kc = km.cluster_centers_.transpose()             # cluster centroid spectra

Results of k-means with ten clusters

In [13]:
fig,axes = plt.subplots(1,3,figsize=(18,5)); axes=axes.ravel()
axes[0].imshow(wtof,cmap='viridis'); axes[0].set_title('Average image')
p=axes[1].plot(kc);                  axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
cmap=ps.buildCMap(p) # Create a color map with the same colors as the plot

im=axes[2].imshow(c,cmap=cmap); plt.colorbar(im);
axes[2].set_title('Cluster map');
plt.tight_layout()

Interpreting the clusters

In [14]:
fig,axes = plt.subplots(1,2,figsize=(14,5)); axes=axes.ravel()
axes[0].matshow(np.corrcoef(kc.transpose()))
axes[1].plot(kc); axes[1].set_title('Cluster centroid spectra'); axes[1].set_aspect(tof.shape[2], adjustable='box')
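The correlation matrix of the centroid spectra suggests a way to fuse redundant clusters. A sketch (the 0.95 threshold is an assumption, and `kc` follows the time bins × clusters layout used above):

```python
import numpy as np

def merge_similar(kc, threshold=0.95):
    """Relabel clusters so that centroids whose spectra correlate above
    `threshold` share a label. kc: centroid spectra, one column per cluster."""
    corr = np.corrcoef(kc.T)      # cluster-by-cluster correlation matrix
    n = corr.shape[0]
    relabel = np.arange(n)
    for i in range(n):
        for j in range(i):
            if corr[i, j] > threshold:
                relabel[i] = relabel[j]   # fuse cluster i into earlier cluster j
                break
    return relabel
```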

Cleaning up the workspace

In [15]:
del km, c, kc, tofr, tof

Supervised segmentation

  1. Training: Requires training data
  2. Verification: Requires verification data
  3. Inference: The images you want to segment
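These three stages can be sketched with scikit-learn on toy data (the blob centers and parameters below are illustrative, not neutron images):

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Toy data: three well-separated point clouds with known labels
X, y = make_blobs(n_samples=300, centers=[[-5, -5], [0, 5], [5, -5]],
                  cluster_std=1.0, random_state=2018)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                    random_state=2018)

clf = KNeighborsClassifier(n_neighbors=3).fit(X_train, y_train)  # 1. training
score = clf.score(X_test, y_test)                                # 2. verification
pred = clf.predict(X_test[:5])                                   # 3. inference
```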

k nearest neighbors

Create example data for supervised segmentation

In [16]:
blob_data, blob_labels = make_blobs(n_samples=100, random_state=2018)
test_pts = pd.DataFrame(blob_data, columns=['x', 'y'])
test_pts['group_id'] = blob_labels
plt.scatter(test_pts.x, test_pts.y, c=test_pts.group_id, cmap='viridis');

Detecting unwanted outliers in neutron images

In [17]:
orig= fits.getdata('../data/spots/mixture12_00001.fits')
annotated=io.imread('../data/spots/mixture12_00001.png'); mask=(annotated[:,:,1]==0)
r=600; c=600; w=256
ps.magnifyRegion(orig,[r,c,r+w,c+w],[15,7],vmin=400,vmax=4000,title='Neutron radiography')

Marked-up spots

Baseline - Traditional spot cleaning algorithm

Parameters

  • N: width of the median filter.
  • k: threshold level for outlier detection.

The spot cleaning algorithm

In [18]:
def spotCleaner(img, threshold=0.95, selem=np.ones([3,3])) :
    fimg=img.astype('float32')
    mimg = flt.median(fimg,selem=selem)
    timg = threshold < np.abs(fimg-mimg)
    cleaned = mimg * timg + fimg * (1-timg)
    return (cleaned,timg)
In [19]:
baseclean,timg = spotCleaner(orig,threshold=1000)
ps.magnifyRegion(baseclean,[r,c,r+w,c+w],[12,3],vmin=400,vmax=4000,title='Cleaned image')
ps.magnifyRegion(timg,[r,c,r+w,c+w],[12,3],vmin=0,vmax=1,title='Detection image')

k nearest neighbors to detect spots

In [20]:
selem=np.ones([3,3])
forig=orig.astype('float32')
mimg = flt.median(forig,selem=selem)
d = np.abs(forig-mimg)

fig,ax=plt.subplots(1,1,figsize=(8,5))
h,x,y,u=ax.hist2d(forig[:1024,:].ravel(),d[:1024,:].ravel(), bins=100);
ax.imshow(np.log(h[::-1]+1),vmin=0,vmax=3,extent=[x.min(),x.max(),y.min(),y.max()])
ax.set_xlabel('Input image - $f$'),ax.set_ylabel('$|f-med_{3x3}(f)|$'),ax.set_title('Log bivariate histogram');

Prepare data

Training data

In [21]:
trainorig = forig[:,:1000].ravel()
traind    = d[:,:1000].ravel()
trainmask = mask[:,:1000].ravel()

train_pts = pd.DataFrame({'orig': trainorig, 'd': traind, 'mask':trainmask})

Test data

In [22]:
testorig = forig[:,1000:].ravel()
testd    = d[:,1000:].ravel()
testmask = mask[:,1000:].ravel()

test_pts = pd.DataFrame({'orig': testorig, 'd': testd, 'mask':testmask})

Train the model

In [23]:
k_class = KNeighborsClassifier(1)
k_class.fit(train_pts[['orig', 'd']], train_pts['mask']) 
Out[23]:
KNeighborsClassifier(n_neighbors=1)

Inspect decision space

In [24]:
xx, yy = np.meshgrid(np.linspace(test_pts.orig.min(), test_pts.orig.max(), 100),
                     np.linspace(test_pts.d.min(), test_pts.d.max(), 100),indexing='ij');
grid_pts = pd.DataFrame(dict(x=xx.ravel(), y=yy.ravel()))
grid_pts['predicted_id'] = k_class.predict(grid_pts[['x', 'y']])
plt.scatter(grid_pts.x, grid_pts.y, c=grid_pts.predicted_id, cmap='gray'); plt.title('Testing Points'); plt.axis('square');

Apply knn to unseen data

In [25]:
pred = k_class.predict(test_pts[['orig', 'd']])
pimg = pred.reshape(d[1000:,:].shape)
In [26]:
fig,ax = plt.subplots(1,3,figsize=(15,6))
ax[0].imshow(forig[1000:,:],vmin=0,vmax=4000), ax[0].set_title('Original image')
ax[1].imshow(pimg), ax[1].set_title('Predicted spot')
ax[2].imshow(mask[1000:,:]),ax[2].set_title('Annotated spots');

Performance check

In [27]:
cmbase = confusion_matrix(mask[:,1000:].ravel(), timg[:,1000:].ravel(), normalize='all')
cmknn  = confusion_matrix(mask[:,1000:].ravel(), pimg.ravel(), normalize='all')
In [28]:
fig,ax = plt.subplots(1,2,figsize=(10,4))
sn.heatmap(cmbase, annot=True,ax=ax[0]), ax[0].set_title('Confusion matrix baseline');
sn.heatmap(cmknn, annot=True,ax=ax[1]), ax[1].set_title('Confusion matrix k-NN');

Some remarks about k-nn

  • It takes more time to process
  • You need to prepare training data
    • Annotation takes time...
    • Here we used the segmentation on the same type of image
    • We should normalize the data
    • This was a raw projection, what happens if we use a flat field corrected image?
  • Finds more spots than baseline
  • Data is very unbalanced, try a selection of non-spot data for training.
    • Is it faster?
    • Is there a drop in segmentation performance?

Note: there are other spot detection methods that perform better than the baseline.

Clean up

In [29]:
del k_class, cmbase, cmknn

Convolutional neural networks for segmentation

In [30]:
import keras.optimizers as opt
import keras.losses as loss
import keras.metrics as metrics

Training data

We have two choices:

  1. Use real data
    • requires time consuming markup to provide training data
    • corresponds to real life images
  2. Synthesize data
    • flexible and provides both 'dirty' data and ground truth.
    • the model may not behave like real data

Preparing real data

We will use the spotty image as training data for this example

Prepare training, validation, and test data

Any analysis system must be verified to demonstrate its performance and to further optimize it.

For this we need to split our data into three categories:

  1. Training data
  2. Test data
  3. Validation data
  Training   Validation   Test
    70%         15%        15%
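The 70/15/15 split above can be sketched as a helper that shuffles and partitions an image stack (function name and default fractions are illustrative):

```python
import numpy as np

def split_stack(images, rng, fractions=(0.70, 0.15, 0.15)):
    """Randomly split a stack of images into train/validation/test subsets."""
    idx = rng.permutation(len(images))          # shuffle before splitting
    n_train = int(round(fractions[0] * len(images)))
    n_val = int(round(fractions[1] * len(images)))
    return (images[idx[:n_train]],
            images[idx[n_train:n_train + n_val]],
            images[idx[n_train + n_val:]])
```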

Build a CNN for spot detection and cleaning

We need:

  • Data
  • Tensorflow
    • Data provider
    • Model design

Build a U-Net model

In [31]:
def buildSpotUNet( base_depth = 48) :
    in_img = Input((None, None, 1), name='Image_Input')
    lay_1 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(in_img)
    lay_2 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(lay_1)
    lay_3 = MaxPooling2D(pool_size=(2, 2))(lay_2)
    lay_4 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_3)
    lay_5 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_4)
    lay_6 = MaxPooling2D(pool_size=(2, 2))(lay_5)
    lay_7 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same',activation='relu')(lay_6)
    lay_8 = Conv2D(base_depth*4, kernel_size=(3, 3), padding='same',activation='relu')(lay_7)
    lay_9 = UpSampling2D((2, 2))(lay_8)
    lay_10 = concatenate([lay_5, lay_9])
    lay_11 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_10)
    lay_12 = Conv2D(base_depth*2, kernel_size=(3, 3), padding='same',activation='relu')(lay_11)
    lay_13 = UpSampling2D((2, 2))(lay_12)
    lay_14 = concatenate([lay_2, lay_13])
    lay_15 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(lay_14)
    lay_16 = Conv2D(base_depth, kernel_size=(3, 3), padding='same',activation='relu')(lay_15)
    lay_17 = Conv2D(1, kernel_size=(1, 1), padding='same',
                    activation='relu')(lay_16)
    t_unet = Model(inputs=[in_img], outputs=[lay_17], name='SpotUNET')
    return t_unet

Model summary

In [32]:
t_unet = buildSpotUNet(base_depth=24)
t_unet.summary()

Model: "SpotUNET"
__________________________________________________________________________________________________
Layer (type)                    Output Shape              Param #    Connected to
==================================================================================================
Image_Input (InputLayer)        (None, None, None, 1)     0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 24)    240        Image_Input[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 24)    5208       conv2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, None, None, 24)    0          conv2d_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, None, None, 48)    10416      max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, None, None, 48)    20784      conv2d_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, None, None, 48)    0          conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, None, None, 96)    41568      max_pooling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, None, None, 96)    83040      conv2d_5[0][0]
__________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)  (None, None, None, 96)    0          conv2d_6[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, None, None, 144)   0          conv2d_4[0][0]
                                                                     up_sampling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, None, None, 48)    62256      concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, None, None, 48)    20784      conv2d_7[0][0]
__________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)  (None, None, None, 48)    0          conv2d_8[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, None, None, 72)    0          conv2d_2[0][0]
                                                                     up_sampling2d_2[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, None, None, 24)    15576      concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, None, None, 24)    5208       conv2d_9[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, None, None, 1)     25         conv2d_10[0][0]
==================================================================================================
Total params: 265,105
Trainable params: 265,105
Non-trainable params: 0
__________________________________________________________________________________________________

Prepare data for training and validation

In [33]:
train_img,  valid_img  = forig[128:256, 500:1300], forig[500:1000, 300:1500]
train_mask, valid_mask = mask[128:256, 500:1300], mask[500:1000, 300:1500]
wpos = [600,600]; ww   = 512
forigc = forig[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]
maskc  = mask[wpos[0]:(wpos[0]+ww),wpos[1]:(wpos[1]+ww)]

# train_img, valid_img = forig[128:256, 300:1500], forig[500:, 300:1500]
# train_mask, valid_mask = mask[128:256, 300:1500], mask[500:, 300:1500]
fig, ax = plt.subplots(1, 4, figsize=(15, 6), dpi=300); ax=ax.ravel()

ax[0].imshow(train_img, cmap='bone',vmin=0,vmax=4000);ax[0].set_title('Train Image')
ax[1].imshow(train_mask, cmap='bone'); ax[1].set_title('Train Mask')

ax[2].imshow(valid_img, cmap='bone',vmin=0,vmax=4000); ax[2].set_title('Validation Image')
ax[3].imshow(valid_mask, cmap='bone');ax[3].set_title('Validation Mask');

Functions to prepare data for training

In [34]:
def prep_img(x, n=1): 
    return (prep_mask(x, n=n)-train_img.mean())/train_img.std()


def prep_mask(x, n=1): 
    return np.stack([np.expand_dims(x, -1)]*n, 0)

Test the untrained model

  • We can make predictions with an untrained model (default parameters)
  • but we clearly do not expect them to be very good
In [35]:
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]

In [36]:
fig, m_axs = plt.subplots(2, 3, figsize=(15, 6), dpi=150)
for c_ax in m_axs.ravel():
    c_ax.axis('off')
((ax1, _, ax2), (ax3, ax4, ax5)) = m_axs
ax1.imshow(train_img, cmap='bone',vmin=0,vmax=4000); ax1.set_title('Train Image')
ax2.imshow(train_mask, cmap='viridis'); ax2.set_title('Train Mask')

ax3.imshow(forigc, cmap='bone',vmin=0, vmax=4000); ax3.set_title('Test Image')
ax4.imshow(unet_pred, cmap='viridis', vmin=0, vmax=1); ax4.set_title('Predicted Segmentation')

ax5.imshow(maskc, cmap='viridis'); ax5.set_title('Ground Truth');

Training conditions

  • Loss function - Binary cross-entropy
  • Optimizer - ADAM
  • 20 Epochs (training iterations)
  • Metrics
    1. Binary accuracy (fraction of pixels correctly classified) $$BA=\frac{1}{N}\sum_i \mathbb{1}(f_i=g_i)$$
    2. Mean absolute error

Another popular metric is the Dice score $$DSC=\frac{2|X \cap Y|}{|X|+|Y|}=\frac{2\,TP}{2TP+FP+FN}$$
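The Dice score is straightforward to compute from two binary masks; a minimal implementation of the formula above:

```python
import numpy as np

def dice(x, y):
    """Dice score DSC = 2|X∩Y| / (|X|+|Y|) for two binary masks."""
    x = x.astype(bool)
    y = y.astype(bool)
    tp = np.logical_and(x, y).sum()          # |X ∩ Y|
    return 2 * tp / (x.sum() + y.sum())      # 2 TP / (2 TP + FP + FN)
```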

In [37]:
mlist = [
      metrics.TruePositives(name='tp'),        metrics.FalsePositives(name='fp'), 
      metrics.TrueNegatives(name='tn'),        metrics.FalseNegatives(name='fn'), 
      metrics.BinaryAccuracy(name='accuracy'), metrics.Precision(name='precision'),
      metrics.Recall(name='recall'),           metrics.AUC(name='auc'),
      metrics.MeanAbsoluteError(name='mae')]

t_unet.compile(
    loss=loss.BinaryCrossentropy(),  # we use the binary cross-entropy to optimize
    optimizer=opt.Adam(lr=1e-3),     # we use ADAM to optimize
    metrics=mlist                    # we keep track of the metrics in mlist
)

A general note on the following demo

This is a very bad way to train a model;

  • the loss function is poorly chosen,
  • the optimizer can be improved and the learning rate can be tuned,
  • the training and validation data should not come from the same sample (and definitely not the same measurement).

The goal is to make you aware of these techniques and to give you a feeling for how they can work on complex problems.

Training the spot detection model

In [38]:
loss_history = t_unet.fit(prep_img(train_img, n=3),
                          prep_mask(train_mask, n=3),
                          validation_data=(prep_img(valid_img),
                                           prep_mask(valid_mask)),
                          epochs=20,
                          verbose = 1)
Train on 3 samples, validate on 1 samples
Epoch 1/20
3/3 [==============================] - 10s 3s/step - loss: 0.0936 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.5055 - mae: 0.0159 - val_loss: 0.0716 - val_tp: 3.0000 - val_fp: 6.0000 - val_tn: 593510.0000 - val_fn: 6481.0000 - val_accuracy: 0.9892 - val_precision: 0.3333 - val_recall: 4.6268e-04 - val_auc: 0.7423 - val_mae: 0.0153
Epoch 2/20
3/3 [==============================] - 7s 2s/step - loss: 0.0545 - tp: 0.0000e+00 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2544.0000 - accuracy: 0.9917 - precision: 0.0000e+00 - recall: 0.0000e+00 - auc: 0.7486 - mae: 0.0139 - val_loss: 0.0594 - val_tp: 12.0000 - val_fp: 14.0000 - val_tn: 593502.0000 - val_fn: 6472.0000 - val_accuracy: 0.9892 - val_precision: 0.4615 - val_recall: 0.0019 - val_auc: 0.7911 - val_mae: 0.0335
Epoch 3/20
3/3 [==============================] - 7s 2s/step - loss: 0.0514 - tp: 3.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2541.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0012 - auc: 0.8261 - mae: 0.0319 - val_loss: 0.0685 - val_tp: 16.0000 - val_fp: 10.0000 - val_tn: 593506.0000 - val_fn: 6468.0000 - val_accuracy: 0.9892 - val_precision: 0.6154 - val_recall: 0.0025 - val_auc: 0.7511 - val_mae: 0.0164
Epoch 4/20
3/3 [==============================] - 7s 2s/step - loss: 0.0622 - tp: 9.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2535.0000 - accuracy: 0.9917 - precision: 1.0000 - recall: 0.0035 - auc: 0.7204 - mae: 0.0132 - val_loss: 0.0602 - val_tp: 24.0000 - val_fp: 19.0000 - val_tn: 593497.0000 - val_fn: 6460.0000 - val_accuracy: 0.9892 - val_precision: 0.5581 - val_recall: 0.0037 - val_auc: 0.8607 - val_mae: 0.0130
Epoch 5/20
3/3 [==============================] - 7s 2s/step - loss: 0.0545 - tp: 18.0000 - fp: 0.0000e+00 - tn: 304656.0000 - fn: 2526.0000 - accuracy: 0.9918 - precision: 1.0000 - recall: 0.0071 - auc: 0.8131 - mae: 0.0105 - val_loss: 0.0447 - val_tp: 42.0000 - val_fp: 50.0000 - val_tn: 593466.0000 - val_fn: 6442.0000 - val_accuracy: 0.9892 - val_precision: 0.4565 - val_recall: 0.0065 - val_auc: 0.9338 - val_mae: 0.0199
Epoch 6/20
3/3 [==============================] - 7s 2s/step - loss: 0.0361 - tp: 24.0000 - fp: 3.0000 - tn: 304653.0000 - fn: 2520.0000 - accuracy: 0.9918 - precision: 0.8889 - recall: 0.0094 - auc: 0.9250 - mae: 0.0166 - val_loss: 0.0441 - val_tp: 47.0000 - val_fp: 54.0000 - val_tn: 593462.0000 - val_fn: 6437.0000 - val_accuracy: 0.9892 - val_precision: 0.4653 - val_recall: 0.0072 - val_auc: 0.9319 - val_mae: 0.0183
Epoch 7/20
3/3 [==============================] - 7s 2s/step - loss: 0.0372 - tp: 27.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2517.0000 - accuracy: 0.9918 - precision: 0.7500 - recall: 0.0106 - auc: 0.9185 - mae: 0.0184 - val_loss: 0.0463 - val_tp: 50.0000 - val_fp: 50.0000 - val_tn: 593466.0000 - val_fn: 6434.0000 - val_accuracy: 0.9892 - val_precision: 0.5000 - val_recall: 0.0077 - val_auc: 0.9219 - val_mae: 0.0146
Epoch 8/20
3/3 [==============================] - 7s 2s/step - loss: 0.0343 - tp: 27.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2517.0000 - accuracy: 0.9918 - precision: 0.7500 - recall: 0.0106 - auc: 0.9329 - mae: 0.0129 - val_loss: 0.0466 - val_tp: 55.0000 - val_fp: 48.0000 - val_tn: 593468.0000 - val_fn: 6429.0000 - val_accuracy: 0.9892 - val_precision: 0.5340 - val_recall: 0.0085 - val_auc: 0.9266 - val_mae: 0.0123
Epoch 9/20
3/3 [==============================] - 7s 2s/step - loss: 0.0348 - tp: 36.0000 - fp: 9.0000 - tn: 304647.0000 - fn: 2508.0000 - accuracy: 0.9918 - precision: 0.8000 - recall: 0.0142 - auc: 0.9309 - mae: 0.0097 - val_loss: 0.0421 - val_tp: 71.0000 - val_fp: 56.0000 - val_tn: 593460.0000 - val_fn: 6413.0000 - val_accuracy: 0.9892 - val_precision: 0.5591 - val_recall: 0.0110 - val_auc: 0.9434 - val_mae: 0.0136
Epoch 10/20
3/3 [==============================] - 7s 2s/step - loss: 0.0327 - tp: 60.0000 - fp: 12.0000 - tn: 304644.0000 - fn: 2484.0000 - accuracy: 0.9919 - precision: 0.8333 - recall: 0.0236 - auc: 0.9435 - mae: 0.0108 - val_loss: 0.0401 - val_tp: 97.0000 - val_fp: 82.0000 - val_tn: 593434.0000 - val_fn: 6387.0000 - val_accuracy: 0.9892 - val_precision: 0.5419 - val_recall: 0.0150 - val_auc: 0.9567 - val_mae: 0.0174
Epoch 11/20
3/3 [==============================] - 7s 2s/step - loss: 0.0312 - tp: 72.0000 - fp: 30.0000 - tn: 304626.0000 - fn: 2472.0000 - accuracy: 0.9919 - precision: 0.7059 - recall: 0.0283 - auc: 0.9574 - mae: 0.0146 - val_loss: 0.0396 - val_tp: 108.0000 - val_fp: 82.0000 - val_tn: 593434.0000 - val_fn: 6376.0000 - val_accuracy: 0.9892 - val_precision: 0.5684 - val_recall: 0.0167 - val_auc: 0.9533 - val_mae: 0.0127
Epoch 12/20
3/3 [==============================] - 7s 2s/step - loss: 0.0295 - tp: 75.0000 - fp: 36.0000 - tn: 304620.0000 - fn: 2469.0000 - accuracy: 0.9918 - precision: 0.6757 - recall: 0.0295 - auc: 0.9546 - mae: 0.0101 - val_loss: 0.0365 - val_tp: 149.0000 - val_fp: 96.0000 - val_tn: 593420.0000 - val_fn: 6335.0000 - val_accuracy: 0.9893 - val_precision: 0.6082 - val_recall: 0.0230 - val_auc: 0.9649 - val_mae: 0.0133
Epoch 13/20
3/3 [==============================] - 7s 2s/step - loss: 0.0273 - tp: 114.0000 - fp: 54.0000 - tn: 304602.0000 - fn: 2430.0000 - accuracy: 0.9919 - precision: 0.6786 - recall: 0.0448 - auc: 0.9680 - mae: 0.0107 - val_loss: 0.0393 - val_tp: 206.0000 - val_fp: 148.0000 - val_tn: 593368.0000 - val_fn: 6278.0000 - val_accuracy: 0.9893 - val_precision: 0.5819 - val_recall: 0.0318 - val_auc: 0.9756 - val_mae: 0.0218
Epoch 14/20
3/3 [==============================] - 7s 2s/step - loss: 0.0339 - tp: 159.0000 - fp: 93.0000 - tn: 304563.0000 - fn: 2385.0000 - accuracy: 0.9919 - precision: 0.6310 - recall: 0.0625 - auc: 0.9741 - mae: 0.0213 - val_loss: 0.0376 - val_tp: 177.0000 - val_fp: 102.0000 - val_tn: 593414.0000 - val_fn: 6307.0000 - val_accuracy: 0.9893 - val_precision: 0.6344 - val_recall: 0.0273 - val_auc: 0.9585 - val_mae: 0.0120
Epoch 15/20
3/3 [==============================] - 7s 2s/step - loss: 0.0296 - tp: 132.0000 - fp: 69.0000 - tn: 304587.0000 - fn: 2412.0000 - accuracy: 0.9919 - precision: 0.6567 - recall: 0.0519 - auc: 0.9548 - mae: 0.0096 - val_loss: 0.0433 - val_tp: 165.0000 - val_fp: 85.0000 - val_tn: 593431.0000 - val_fn: 6319.0000 - val_accuracy: 0.9893 - val_precision: 0.6600 - val_recall: 0.0254 - val_auc: 0.9356 - val_mae: 0.0114
Epoch 16/20
3/3 [==============================] - 7s 2s/step - loss: 0.0346 - tp: 129.0000 - fp: 54.0000 - tn: 304602.0000 - fn: 2415.0000 - accuracy: 0.9920 - precision: 0.7049 - recall: 0.0507 - auc: 0.9272 - mae: 0.0091 - val_loss: 0.0432 - val_tp: 169.0000 - val_fp: 84.0000 - val_tn: 593432.0000 - val_fn: 6315.0000 - val_accuracy: 0.9893 - val_precision: 0.6680 - val_recall: 0.0261 - val_auc: 0.9349 - val_mae: 0.0113
Epoch 17/20
3/3 [==============================] - 7s 2s/step - loss: 0.0342 - tp: 132.0000 - fp: 54.0000 - tn: 304602.0000 - fn: 2412.0000 - accuracy: 0.9920 - precision: 0.7097 - recall: 0.0519 - auc: 0.9257 - mae: 0.0090 - val_loss: 0.0363 - val_tp: 216.0000 - val_fp: 110.0000 - val_tn: 593406.0000 - val_fn: 6268.0000 - val_accuracy: 0.9894 - val_precision: 0.6626 - val_recall: 0.0333 - val_auc: 0.9623 - val_mae: 0.0125
Epoch 18/20
3/3 [==============================] - 7s 2s/step - loss: 0.0297 - tp: 150.0000 - fp: 81.0000 - tn: 304575.0000 - fn: 2394.0000 - accuracy: 0.9919 - precision: 0.6494 - recall: 0.0590 - auc: 0.9514 - mae: 0.0101 - val_loss: 0.0375 - val_tp: 263.0000 - val_fp: 142.0000 - val_tn: 593374.0000 - val_fn: 6221.0000 - val_accuracy: 0.9894 - val_precision: 0.6494 - val_recall: 0.0406 - val_auc: 0.9667 - val_mae: 0.0188
Epoch 19/20
3/3 [==============================] - 7s 2s/step - loss: 0.0299 - tp: 186.0000 - fp: 99.0000 - tn: 304557.0000 - fn: 2358.0000 - accuracy: 0.9920 - precision: 0.6526 - recall: 0.0731 - auc: 0.9610 - mae: 0.0153 - val_loss: 0.0410 - val_tp: 314.0000 - val_fp: 173.0000 - val_tn: 593343.0000 - val_fn: 6170.0000 - val_accuracy: 0.9894 - val_precision: 0.6448 - val_recall: 0.0484 - val_auc: 0.9669 - val_mae: 0.0243
Epoch 20/20
3/3 [==============================] - 7s 2s/step - loss: 0.0336 - tp: 204.0000 - fp: 114.0000 - tn: 304542.0000 - fn: 2340.0000 - accuracy: 0.9920 - precision: 0.6415 - recall: 0.0802 - auc: 0.9638 - mae: 0.0206 - val_loss: 0.0391 - val_tp: 326.0000 - val_fp: 183.0000 - val_tn: 593333.0000 - val_fn: 6158.0000 - val_accuracy: 0.9894 - val_precision: 0.6405 - val_recall: 0.0503 - val_auc: 0.9711 - val_mae: 0.0224
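All the scalar metrics in the log above are derived from the four confusion counts (`tp`, `fp`, `tn`, `fn`) that Keras accumulates each epoch. As a sanity check, a small helper (`metrics_from_counts` is hypothetical, for illustration only) reproduces the final training-epoch values:

```python
def metrics_from_counts(tp, fp, tn, fn):
    """Derive the scalar metrics reported in the log from the raw confusion counts."""
    total = tp + fp + tn + fn
    return {'accuracy':  (tp + tn) / total,
            'precision': tp / (tp + fp),
            'recall':    tp / (tp + fn)}

# Counts taken from the last training epoch above
m = metrics_from_counts(tp=204, fp=114, tn=304542, fn=2340)
print(m)  # accuracy ~0.9920, precision ~0.6415, recall ~0.0802
```

Note how the very high accuracy is misleading here: the classes are extremely imbalanced, so precision and recall are the more informative numbers.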

Training history plots

In [39]:
titleDict = {'tp':        "True Positives",
             'fp':        "False Positives",
             'tn':        "True Negatives",
             'fn':        "False Negatives",
             'accuracy':  "Binary Accuracy",
             'precision': "Precision",
             'recall':    "Recall",
             'auc':       "Area under Curve",
             'mae':       "Mean absolute error"}

fig, ax = plt.subplots(2, 5, figsize=(20, 8), dpi=300)
ax = ax.ravel()
for idx, key in enumerate(titleDict.keys()):
    ax[idx].plot(loss_history.epoch, loss_history.history[key], color='coral', label='Training')
    ax[idx].plot(loss_history.epoch, loss_history.history['val_'+key], color='cornflowerblue', label='Validation')
    ax[idx].set_title(titleDict[key])

ax[9].axis('off')  # only nine metrics, so the tenth panel stays empty
lines, labels = ax[0].get_legend_handles_labels()  # take the labels and line styles from the first panel
fig.legend(lines, labels, bbox_to_anchor=(0.7, 0.3), loc='upper left');

Prediction on the training data

In [40]:
unet_train_pred = t_unet.predict(prep_img(train_img[:,wpos[1]:(wpos[1]+ww)]))[0, :, :, 0]

fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150)
m_axs = m_axs.ravel()
for c_ax in m_axs: c_ax.axis('off')

m_axs[0].imshow(train_img[:, wpos[1]:(wpos[1]+ww)], cmap='bone', vmin=0, vmax=4000); m_axs[0].set_title('Train Image')
m_axs[1].imshow(unet_train_pred, cmap='viridis', vmin=0, vmax=0.2); m_axs[1].set_title('Predicted Training')
m_axs[2].imshow(train_mask[:, wpos[1]:(wpos[1]+ww)], cmap='viridis'); m_axs[2].set_title('Train Mask');

Prediction using unseen data

In [41]:
unet_pred = t_unet.predict(prep_img(forigc))[0, :, :, 0]

fig, m_axs = plt.subplots(1, 3, figsize=(18, 4), dpi=150)
m_axs = m_axs.ravel()
for c_ax in m_axs: c_ax.axis('off')

m_axs[0].imshow(forigc, cmap='bone', vmin=0, vmax=4000); m_axs[0].set_title('Full Image')
f1 = m_axs[1].imshow(unet_pred, cmap='viridis', vmin=0, vmax=0.1); m_axs[1].set_title('Predicted Segmentation')
fig.colorbar(f1, ax=m_axs[1])
m_axs[2].imshow(maskc, cmap='viridis'); m_axs[2].set_title('Ground Truth');

Converting predictions to segments

In [42]:
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax0 = ax[0].imshow(unet_pred, vmin=0, vmax=0.1); ax[0].set_title('Predicted segmentation'); fig.colorbar(ax0, ax=ax[0])
ax[1].imshow(0.05 < unet_pred); ax[1].set_title('Final segmentation');
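The 0.05 cut-off used above is a free parameter. When a ground-truth mask is available, one way to motivate the choice is to sweep candidate thresholds and score each binarization, e.g. with the Dice coefficient. The sketch below runs on a synthetic probability map; in the notebook, `unet_pred` and `maskc` would take the place of `pred` and `gt`:

```python
import numpy as np

def dice(gt, pr):
    """Dice coefficient between two boolean masks (1.0 = perfect overlap)."""
    inter = np.logical_and(gt, pr).sum()
    return 2 * inter / (gt.sum() + pr.sum())

rng  = np.random.default_rng(42)
gt   = np.zeros((64, 64), dtype=bool)
gt[20:40, 20:40] = True                            # toy ground-truth mask
pred = gt * 0.08 + rng.uniform(0, 0.05, gt.shape)  # noisy pseudo-probability map

thresholds = np.linspace(0.01, 0.11, 11)
scores     = [dice(gt, th < pred) for th in thresholds]
best       = thresholds[int(np.argmax(scores))]    # best-scoring threshold
```

On real predictions the curve is rarely this clean, but the sweep makes the sensitivity of the final segmentation to the threshold explicit.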

Hit cases

In [43]:
gt = maskc
pr = 0.05<unet_pred
ps.showHitCases(gt,pr,cmap='gray')

Hit map

In [44]:
fig, ax = plt.subplots(1,2,figsize=(12,4))

ps.showHitMap(gt,pr,ax=ax)
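The hit map shows *where* the four outcome categories occur; counting them gives the overall pixel-wise confusion matrix. A minimal numpy sketch (`pixel_confusion` is a hypothetical helper; in the notebook `gt` and `pr` are the masks defined above, toy masks are used here):

```python
import numpy as np

def pixel_confusion(gt, pr):
    """Count the per-pixel hit categories: (tp, fp, tn, fn)."""
    gt = np.asarray(gt, dtype=bool)
    pr = np.asarray(pr, dtype=bool)
    tp = np.sum(gt & pr)    # hit
    fp = np.sum(~gt & pr)   # false alarm
    tn = np.sum(~gt & ~pr)  # correct rejection
    fn = np.sum(gt & ~pr)   # miss
    return tp, fp, tn, fn

# Toy masks: ground truth has 4 positive pixels, the prediction hits 3 of them
gt = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 0]], dtype=bool)
pr = np.array([[1, 1, 0], [1, 0, 1], [0, 0, 0]], dtype=bool)
tp, fp, tn, fn = pixel_confusion(gt, pr)
print(tp, fp, tn, fn)  # 3 1 4 1
```

These are the same four counts Keras tracked during training, now evaluated on the final thresholded segmentation.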

Concluding remarks about the spot detection

Segmenting root networks in the rhizosphere using a U-Net

Background

  • Soil, and in particular the rhizosphere, is of central interest for neutron imaging users.
  • The experiments aim to follow the water distribution near the roots.
  • The roots must be identified in 2D and 3D data.
  • Today, much of this mark-up is still done manually!

Available data

Considered NN models

Loss functions

Training

Results

Summary

Future Machine learning challenges in neutron imaging

Concluding remarks

In [ ]: